from hpccm.primitives import baseimage
from hpccm.building_blocks import *

# HPCCM recipe: CUDA 12.8 / Ubuntu 22.04 image with a UCX-based Open MPI
# stack, parallel HDF5, and QUDA built from source.
Stage0 += baseimage(image="nvidia/cuda:12.8.1-devel-ubuntu22.04")
# Alternative base: bootstrap from a local Singularity image instead of a
# registry image (uncomment and comment out the line above).
#Stage0 += baseimage(image="/CLQCD/containers/update_cuda11.8_ubuntu22.04.sif",_bootstrap = "localimage")


Stage0 += gnu()

# Communication stack. gdrcopy/knem/xpmem/ucx install under their default
# /usr/local/<name> prefixes; UCX and Open MPI configure are given those
# paths explicitly rather than relying on auto-detection of non-system
# install locations.
Stage0 += mlnx_ofed()
Stage0 += gdrcopy()
Stage0 += knem()
Stage0 += xpmem()
Stage0 += ucx(
    version="1.18.0",
    ofed=True,
    gdrcopy="/usr/local/gdrcopy",
    knem="/usr/local/knem",
    xpmem="/usr/local/xpmem",
)
# Network transport is delegated to UCX, so Open MPI's own verbs support
# stays disabled (infiniband=False).
ompi = openmpi(infiniband=False, ucx="/usr/local/ucx", version="4.1.8")
Stage0 += ompi

# Parallel HDF5 must be compiled with the MPI compiler wrappers, hence the
# Open MPI toolchain. HDF5's configure also refuses --enable-cxx together
# with --enable-parallel, so the building block's default configure options
# (--enable-cxx --enable-fortran) are replaced with Fortran only.
Stage0 += hdf5(
    configure_opts=["--enable-fortran"],
    enable_parallel=True,
    enable_shared=False,
    toolchain=ompi.toolchain,
)

# eula=True accepts the Kitware license required by the cmake building block.
Stage0 += cmake(eula=True)
# QUDA from the develop branch. NOTE(review): develop is a moving target;
# pin a release tag here if reproducible images are needed.
Stage0 += generic_cmake(
    branch="develop",
    cmake_opts=[
        "-D CMAKE_BUILD_TYPE=RELEASE",
        # QUDA's switch for building libquda as a shared library
        # (the previous "enable_shared" is not a QUDA option and was ignored).
        "-D QUDA_BUILD_SHAREDLIB=ON",
        # Target compute capability; must match the GPUs the image runs on
        # (sm_60 = Pascal). TODO confirm for the deployment cluster.
        "-D QUDA_GPU_ARCH=sm_60",
        "-D QUDA_MULTIGRID=ON",
        "-D QUDA_COVDEV=ON",
        "-D QUDA_CLOVER_DYNAMIC=OFF",
        "-D QUDA_CLOVER_RECONSTRUCT=OFF",
        # Disable all Dirac operators by default, then enable only those needed.
        "-D QUDA_DIRAC_DEFAULT_OFF=ON",
        "-D QUDA_DIRAC_WILSON=ON",
        "-D QUDA_DIRAC_CLOVER=ON",
        "-D QUDA_DIRAC_STAGGERED=ON",
        "-D QUDA_DIRAC_LAPLACE=ON",
        "-D QUDA_MPI=ON",
    ],
    prefix="/usr/local/quda",
    repository="https://github.com/lattice/quda.git",
)
